package com.personal.jdbc.operator;

import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.datasource.DataSourceUtils;

import com.personal.core.data.DataColumn;
import com.personal.core.data.DataRow;
import com.personal.core.data.DataTable;
import com.personal.core.data.Page;
import com.personal.core.utils.Assert;
import com.personal.core.utils.CoreUtil;
import com.personal.core.utils.LogUtil;
import com.personal.jdbc.dialect.JdbcDialect;
import com.personal.jdbc.exception.JdbcException;

/**
 * JDBC helper built on Spring's {@link JdbcTemplate}: plain queries, dialect-based paging and
 * counting, {@link DataTable} mapping, and manually batched updates.
 * <p>
 * The SQL dialect is detected once, in the constructor, from the database product name.
 *
 * @author qq
 *
 */
public class JdbcOperator
{
    private final JdbcTemplate template;

    private JdbcDialect dialect;

    /** Number of statements sent per {@code executeBatch()} round trip by the manual batch methods. */
    private static final int BATCH_COUNT = 2000;

    /**
     * Creates an operator bound to {@code dataSource} and detects its SQL dialect.
     * @param dataSource the data source used for every operation
     * @throws JdbcException if the database metadata cannot be read
     */
    public JdbcOperator(DataSource dataSource)
    {
        super();
        this.template = new JdbcTemplate(dataSource);
        this.getDialect(dataSource);
    }

    /**
     * Runs a query and returns every row as a column-name to value map.
     * @param sql SQL with {@code ?} placeholders
     * @param values placeholder values, in order
     * @return all rows (as returned by {@link JdbcTemplate#queryForList(String, Object...)})
     */
    public List<Map<String, Object>> queryForList(String sql, Object... values)
    {
        return this.template.queryForList(sql, values);
    }

    /**
     * Runs a query rewritten for paging by the detected dialect
     * ({@code makePagging(sql, start, limit)}). The returned page holds the rows plus
     * {@code start} and {@code limit}; the total count is not computed.
     * @param sql SQL with {@code ?} placeholders
     * @param start paging offset as interpreted by the dialect (presumably the first row index — confirm)
     * @param limit maximum number of rows as interpreted by the dialect
     * @param values placeholder values, in order
     * @return the page of rows
     */
    public Page<Map<String, Object>> queryForPage(String sql, int start, int limit, Object... values)
    {
        List<Map<String, Object>> list = this.template.queryForList(this.dialect.makePagging(sql, start, limit),
                values);
        Page<Map<String, Object>> result = new Page<>();
        result.setData(list);
        result.setStart(start);
        result.setLimit(limit);
        return result;
    }

    /**
     * Same as {@link #queryForPage}, and additionally fills the total count using the dialect's
     * count rewrite of {@code sql}.
     * @param sql SQL with {@code ?} placeholders
     * @param start paging offset as interpreted by the dialect
     * @param limit maximum number of rows as interpreted by the dialect
     * @param values placeholder values, in order (used for both the page and the count query)
     * @return the page of rows with its total count set
     * @throws SQLException declared for source compatibility with existing callers; template
     *         failures are reported as unchecked {@link DataAccessException}s
     */
    public Page<Map<String, Object>> queryForPageAndTotalCount(String sql, int start, int limit, Object... values)
            throws SQLException
    {
        Page<Map<String, Object>> result = queryForPage(sql, start, limit, values);
        if (result == null)
        {
            return null;
        }
        result.setTotalCount(this.template.queryForObject(this.dialect.makeCount(sql), Long.class, values));
        return result;
    }

    /**
     * Executes a single non-query statement (insert/update/delete/DDL).
     * @param sql SQL with {@code ?} placeholders
     * @param values placeholder values, in order
     * @return the update count reported by the driver
     */
    public int executeNonQuerySQL(String sql, Object... values)
    {
        return this.template.update(sql, values);
    }

    /**
     * Creates a {@link Blob} from the data source's connection and fills it with {@code data}.
     * <p>
     * NOTE(review): the connection is released before the Blob is returned. Outside a
     * Spring-managed transaction, some drivers invalidate LOBs once their connection is closed —
     * confirm the Blob is still usable by callers.
     * @param data the bytes to store; must not be null
     * @return the populated Blob
     * @throws IllegalArgumentException if {@code data} is null
     * @throws JdbcException if the driver fails to create or write the Blob
     */
    public Blob createBlob(byte[] data)
    {
        if (data == null)
        {
            throw new IllegalArgumentException("data must not be null");
        }
        Connection connection = null;
        try
        {
            connection = DataSourceUtils.getConnection(this.template.getDataSource());
            Blob blob = connection.createBlob();
            // Blob positions are 1-based: write the payload starting at the first byte.
            blob.setBytes(1, data);
            return blob;
        } catch (SQLException e)
        {
            LogUtil.log(e);
            throw new JdbcException(e.getMessage(), e);
        } finally
        {
            this.releaseResource(connection, (PreparedStatement) null, (ResultSet) null);
        }
    }

    /**
     * Returns the first row of the query, or {@code null} if the query returns no rows.
     * @param sql SQL with {@code ?} placeholders
     * @param values placeholder values, in order
     * @return the first row, or {@code null}
     */
    public Map<String, Object> queryFirstRecord(String sql, Object... values)
    {
        List<Map<String, Object>> list = queryForList(sql, values);
        return (list == null || list.isEmpty()) ? null : list.get(0);
    }

    /**
     * Returns column 1 of the first row, converted to {@code t}, or {@code null} if there are no rows.
     * @param sql SQL with {@code ?} placeholders
     * @param t target Java type for {@link ResultSet#getObject(int, Class)}
     * @param values placeholder values, in order
     * @return the value, or {@code null}
     */
    public <T> T queryFirstRecordValue(String sql, Class<T> t, Object... values)
    {
        return this.template.query(sql, values, new ResultSetExtractor<T>()
        {
            @Override
            public T extractData(ResultSet rs) throws SQLException, DataAccessException
            {
                if (rs.next())
                {
                    return rs.getObject(1, t);
                }
                return null;
            }
        });
    }

    /**
     * Returns column 1 of the last row, converted to {@code t}, or {@code null} if there are no rows.
     * <p>
     * The rows are walked forward instead of calling {@link ResultSet#last()}. JdbcTemplate creates
     * forward-only result sets by default, and {@code last()} throws on those.
     * @param sql SQL with {@code ?} placeholders
     * @param t target Java type for {@link ResultSet#getObject(int, Class)}
     * @param values placeholder values, in order
     * @return the value, or {@code null}
     */
    public <T> T queryLastRecordValue(String sql, Class<T> t, Object... values)
    {
        return this.template.query(sql, values, new ResultSetExtractor<T>()
        {
            @Override
            public T extractData(ResultSet rs) throws SQLException, DataAccessException
            {
                T last = null;
                while (rs.next())
                {
                    last = rs.getObject(1, t);
                }
                return last;
            }
        });
    }

    /**
     * Counts the rows of {@code sql} using the dialect's count rewrite.
     * @param sql SQL with {@code ?} placeholders
     * @param values placeholder values, in order
     * @return the row count
     */
    public long queryDataCount(String sql, Object... values)
    {
        return this.template.queryForObject(this.dialect.makeCount(sql), Long.class, values);
    }

    /**
     * Inserts every row of {@code dataTable} into {@code dataTable.getTableName()} using one
     * parameterized INSERT executed via {@link #batchUpdate(String, List)}.
     * <p>
     * The table and column names are concatenated into the SQL, not bound as parameters, so they
     * must come from trusted metadata.
     * @param dataTable the table to insert; nothing is done if it has no data
     * @return 0 if there is no data; otherwise the result of {@link #batchUpdate(String, List)}
     */
    public int batchInsert(DataTable dataTable)
    {
        if (!CoreUtil.checkDataTableHasData(dataTable))
        {
            return 0;
        }
        StringBuilder before = new StringBuilder();
        StringBuilder after = new StringBuilder(" (");
        before.append("insert into ").append(dataTable.getTableName()).append(" (");
        for (DataColumn column : dataTable.getColumns())
        {
            before.append(column.getColumnName()).append(",");
            after.append(" ?,");
        }
        // Drop the trailing commas left by the loop above.
        after.deleteCharAt(after.length() - 1).append(" )");
        before.deleteCharAt(before.length() - 1).append(" ) values ").append(after);
        int columnCount = dataTable.getColumns().size();
        List<Object[]> list = new ArrayList<Object[]>();
        for (DataRow row : dataTable.getRows())
        {
            Object[] arr = new Object[columnCount];
            int colIndex = 0;
            for (DataColumn column : dataTable.getColumns())
            {
                arr[colIndex++] = row.getValue(column);
            }
            list.add(arr);
        }
        return batchUpdate(before.toString(), list);
    }

    /**
     * Executes the statements via {@link JdbcTemplate#batchUpdate(String...)}.
     * @param sqls the SQL statements
     * @return the sum of the per-statement update counts reported by the driver
     */
    public int batchUpdateUseJdbcTemplate(String... sqls)
    {
        return sumUpdateCounts(this.template.batchUpdate(sqls));
    }

    /**
     * Executes the statements as JDBC batches of {@link #BATCH_COUNT} inside one local transaction
     * that is committed at the end. The connection's original auto-commit mode is restored afterwards.
     * <p>
     * NOTE(review): the commit/rollback is issued directly on the connection obtained from
     * {@link DataSourceUtils}. Inside a Spring-managed transaction this ends that transaction early.
     * @param sqls the SQL statements; null or empty does nothing
     * @return the number of statements on success; 0 on SQLException (logged and rolled back)
     */
    public int batchUpdate(String... sqls)
    {
        if (sqls == null || sqls.length == 0)
        {
            return 0;
        }
        Connection connection = null;
        Statement statement = null;
        Boolean originalAutoCommit = null;
        try
        {
            connection = getConnection();
            originalAutoCommit = connection.getAutoCommit();
            begainTransaction(connection);
            statement = connection.createStatement();
            int pending = 0;
            for (String sql : sqls)
            {
                statement.addBatch(sql);
                if (++pending == BATCH_COUNT)
                {
                    statement.executeBatch();
                    statement.clearBatch();
                    pending = 0;
                }
            }
            // Flush the remainder; skipped when the last chunk was already flushed.
            if (pending > 0)
            {
                statement.executeBatch();
                statement.clearBatch();
            }
            connection.commit();
            return sqls.length;
        } catch (SQLException e)
        {
            LogUtil.log(e);
            rollbackIfPresent(connection);
            return 0;
        } catch (RuntimeException e)
        {
            // Do not leave a half-applied batch pending on the connection.
            rollbackIfPresent(connection);
            throw e;
        } finally
        {
            restoreAutoCommit(connection, originalAutoCommit);
            releaseResource(connection, statement, null);
        }
    }

    /**
     * Executes one parameterized statement for each argument array via
     * {@link JdbcTemplate#batchUpdate(String, List)}.
     * @param sql SQL with {@code ?} placeholders
     * @param batchArgs one argument array per execution
     * @return the sum of the per-statement update counts reported by the driver
     */
    public int batchUpdateUseJdbcTemplate(String sql, List<Object[]> batchArgs)
    {
        return sumUpdateCounts(this.template.batchUpdate(sql, batchArgs));
    }

    /**
     * Executes {@code sql} once per entry of {@code batchArgs}, as JDBC batches of
     * {@link #BATCH_COUNT} inside one local transaction that is committed at the end. The
     * connection's original auto-commit mode is restored afterwards.
     * <p>
     * NOTE(review): the commit/rollback is issued directly on the connection obtained from
     * {@link DataSourceUtils}. Inside a Spring-managed transaction this ends that transaction early.
     * @param sql SQL with {@code ?} placeholders
     * @param batchArgs one argument array per execution; null or empty does nothing
     * @return the number of argument arrays on success; 0 on SQLException (logged and rolled back)
     */
    public int batchUpdate(String sql, List<Object[]> batchArgs)
    {
        if (batchArgs == null || batchArgs.isEmpty())
        {
            return 0;
        }
        Connection connection = null;
        PreparedStatement statement = null;
        Boolean originalAutoCommit = null;
        try
        {
            connection = getConnection();
            originalAutoCommit = connection.getAutoCommit();
            begainTransaction(connection);
            statement = connection.prepareStatement(sql);
            int pending = 0;
            for (Object[] args : batchArgs)
            {
                for (int j = 0; j < args.length; j++)
                {
                    statement.setObject(j + 1, args[j]);
                }
                statement.addBatch();
                if (++pending == BATCH_COUNT)
                {
                    statement.executeBatch();
                    statement.clearBatch();
                    pending = 0;
                }
            }
            // Flush the remainder; skipped when the last chunk was already flushed.
            if (pending > 0)
            {
                statement.executeBatch();
                statement.clearBatch();
            }
            connection.commit();
            return batchArgs.size();
        } catch (SQLException e)
        {
            LogUtil.log(e);
            rollbackIfPresent(connection);
            return 0;
        } catch (RuntimeException e)
        {
            // E.g. a null argument array: do not leave a half-applied batch pending.
            rollbackIfPresent(connection);
            throw e;
        } finally
        {
            restoreAutoCommit(connection, originalAutoCommit);
            releaseResource(connection, statement, null);
        }
    }

    /**
     * Runs a query and copies its metadata and rows into a new {@link DataTable}.
     * For each result column, the label, name and SQL type are recorded, and every row's values
     * are read with {@link ResultSet#getObject(int)}.
     * @param sql SQL with {@code ?} placeholders
     * @param values placeholder values, in order
     * @return the populated table (no rows if the query returns nothing)
     * @throws Exception declared for source compatibility; template failures are reported as
     *         unchecked {@link DataAccessException}s
     */
    public DataTable queryForDataTable(String sql, Object... values) throws Exception
    {
        return this.template.query(sql, values, new ResultSetExtractor<DataTable>()
        {
            @Override
            public DataTable extractData(ResultSet rs) throws SQLException, DataAccessException
            {
                ResultSetMetaData metaData = rs.getMetaData();
                int colCount = metaData.getColumnCount();
                DataTable result = new DataTable();
                for (int i = 0; i < colCount; i++)
                {
                    DataColumn column = new DataColumn();
                    column.setColumnLable(metaData.getColumnLabel(i + 1));
                    column.setColumnName(metaData.getColumnName(i + 1));
                    column.setSqlType(metaData.getColumnType(i + 1));
                    result.getColumns().add(column);
                }
                while (rs.next())
                {
                    DataRow newRow = result.addNewRow();
                    int colIndex = 1;
                    for (DataColumn column : result.getColumns())
                    {
                        newRow.setValue(column, rs.getObject(colIndex++));
                    }
                }
                return result;
            }
        });
    }

    /**
     * Obtains a connection through {@link DataSourceUtils}, so a transaction-bound connection is
     * reused when one exists. Callers are responsible for releasing it.
     * @return a connection from the operator's data source
     */
    public Connection getConnection()
    {
        return DataSourceUtils.getConnection(this.template.getDataSource());
    }

    /**
     * Starts a local transaction by disabling auto-commit.
     * A SQLException is logged and not rethrown.
     * @param con the connection; rejected by {@code Assert.isNotNull} when null
     */
    public void begainTransaction(Connection con)
    {
        Assert.isNotNull(con, "连接不能为空！");
        try
        {
            con.setAutoCommit(false);
        } catch (SQLException e)
        {
            LogUtil.log(e);
        }
    }

    /**
     * Commits the connection's current transaction.
     * A SQLException is logged and not rethrown.
     * @param con the connection; rejected by {@code Assert.isNotNull} when null
     */
    public void commitTransaction(Connection con)
    {
        Assert.isNotNull(con, "连接不能为空！");
        try
        {
            con.commit();
        } catch (SQLException e)
        {
            LogUtil.log(e);
        }
    }

    /**
     * Rolls back the connection's current transaction.
     * A SQLException is logged and not rethrown.
     * @param con the connection; rejected by {@code Assert.isNotNull} when null
     */
    public void rollbackTransaction(Connection con)
    {
        Assert.isNotNull(con, "连接不能为空！");
        try
        {
            con.rollback();
        } catch (SQLException e)
        {
            LogUtil.log(e);
        }
    }

    /** Rolls back {@code con} when it was actually obtained (it is null if acquisition failed). */
    private void rollbackIfPresent(Connection con)
    {
        if (con != null)
        {
            rollbackTransaction(con);
        }
    }

    /**
     * Puts back the auto-commit mode recorded before a local transaction, so the connection is
     * not returned to its pool still in manual-commit mode. Failures are logged only.
     */
    private void restoreAutoCommit(Connection con, Boolean originalAutoCommit)
    {
        if (con == null || originalAutoCommit == null)
        {
            return;
        }
        try
        {
            if (con.getAutoCommit() != originalAutoCommit.booleanValue())
            {
                con.setAutoCommit(originalAutoCommit.booleanValue());
            }
        } catch (SQLException e)
        {
            LogUtil.log(e);
        }
    }

    /**
     * Adds up per-statement batch update counts.
     * NOTE(review): drivers may report {@code Statement.SUCCESS_NO_INFO} (-2), which is summed as-is.
     */
    private static int sumUpdateCounts(int[] counts)
    {
        int total = 0;
        for (int count : counts)
        {
            total += count;
        }
        return total;
    }

    /**
     * Closes the result set and statement (logging close failures), then hands the connection
     * back via {@link DataSourceUtils#releaseConnection}. Any argument may be null.
     */
    private void releaseResource(Connection conn, Statement st, ResultSet rs)
    {
        if (rs != null)
        {
            try
            {
                rs.close();
            } catch (SQLException e)
            {
                LogUtil.log(e);
            }
        }

        if (st != null)
        {
            try
            {
                st.close();
            } catch (SQLException e)
            {
                LogUtil.log(e);
            }
        }
        DataSourceUtils.releaseConnection(conn, this.template.getDataSource());
    }

    /**
     * Detects the dialect from the database product name and stores it in {@code dialect}.
     * @throws JdbcException if the metadata cannot be read
     */
    private void getDialect(DataSource dataSource)
    {
        Connection connection = null;
        try
        {
            connection = DataSourceUtils.getConnection(dataSource);
            String productName = connection.getMetaData().getDatabaseProductName();
            this.dialect = JdbcDialect.getDialect(productName);
        } catch (SQLException e)
        {
            LogUtil.log(e);
            throw new JdbcException(e.getMessage(), e);
        } finally
        {
            this.releaseResource(connection, (PreparedStatement) null, (ResultSet) null);
        }

    }

    /** @return the underlying Spring template */
    public JdbcTemplate getTemplate()
    {
        return template;
    }

    /** @return the dialect detected at construction time */
    public JdbcDialect getDialect()
    {
        return dialect;
    }
    
}
